{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Critique-Revise\n",
    "\n",
    "You can also run this notebook online [on Noteable.io](https://app.noteable.io/published/41cf4f55-bd8c-4b7f-b3d4-547d5e601960/basic_critique_revise).\n",
    "\n",
    "This is an basic example of correcting an LLM's output using a pattern called critique-revise, where we highlight what part of the output is wrong and re-query the LLM for a correction.\n",
    "\n",
    "In the below example, we define a Zod schema to match an array of tasks we want the LLM to compose for a given input. Take note that we expect the output tasks to match one of a defined set of types:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "Deno.env.set(\"OPENAI_API_KEY\", \"\");\n",
    "Deno.env.set(\"LANGCHAIN_API_KEY\", \"\");\n",
    "Deno.env.set(\"LANGCHAIN_TRACING_V2\", \"true\");\n",
    "\n",
    "import { z } from \"npm:zod\";\n",
    "\n",
    "const zodSchema = z\n",
    "  .object({\n",
    "    tasks: z\n",
    "      .array(\n",
    "        z.object({\n",
    "          title: z\n",
    "            .string()\n",
    "            .describe(\"The title of the tasks, reminders and alerts\"),\n",
    "          due_date: z\n",
    "            .string()\n",
    "            .describe(\"Due date. Must be a valid ISO date string with timezone\"),\n",
    "          task_type: z\n",
    "            .enum([\n",
    "              \"Call\",\n",
    "              \"Message\",\n",
    "              \"Todo\",\n",
    "              \"In-Person Meeting\",\n",
    "              \"Email\",\n",
    "              \"Mail\",\n",
    "              \"Text\",\n",
    "              \"Open House\",\n",
    "            ])\n",
    "            .describe(\"The type of task\"),\n",
    "        })\n",
    "      )\n",
    "      .describe(\"The JSON for task, reminder or alert to create\"),\n",
    "  })\n",
    "  .describe(\"JSON definition for creating tasks, reminders and alerts\");\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We convert this Zod schema to JSON schema and bind it to an OpenAI model as a function. We also set the `function_call` argument so that the function will always be called and the LLM will always attempt to match its output to the provided schema, regardless of input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { zodToJsonSchema } from \"npm:zod-to-json-schema\";\n",
    "\n",
    "import { JsonOutputFunctionsParser } from \"npm:langchain@0.0.173/output_parsers\";\n",
    "import { ChatOpenAI } from \"npm:langchain@0.0.173/chat_models/openai\";\n",
    "import { PromptTemplate } from \"npm:langchain@0.0.173/prompts\";\n",
    "\n",
    "const functionSchema = {\n",
    "  name: \"task-scheduler\",\n",
    "  description: \"Schedules tasks\",\n",
    "  parameters: zodToJsonSchema(zodSchema)\n",
    "};\n",
    "\n",
    "const template = `Respond to the following user query to the best of your ability:\n",
    "\n",
    "{query}`;\n",
    "\n",
    "const generatePrompt = PromptTemplate.fromTemplate(template);\n",
    "\n",
    "const taskFunctionCallModel = new ChatOpenAI({\n",
    "  temperature: 0,\n",
    "  modelName: \"gpt-3.5-turbo\",\n",
    "}).bind({\n",
    "  functions: [functionSchema],\n",
    "  function_call: { name: \"task-scheduler\" },\n",
    "});\n",
    "\n",
    "const generateChain = generatePrompt\n",
    "  .pipe(taskFunctionCallModel)\n",
    "  .pipe(new JsonOutputFunctionsParser()) \n",
    "  .withConfig({ runName: \"GenerateChain\" });"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we invoke the chain on an input. We can see that despite the passed function schema, the model was influenced by the particular user input, and the the outputted task type does not match one of the defined types, and is instead `\"reminder\"`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  tasks: [\n",
      "    {\n",
      "      title: \"Renew online property ads\",\n",
      "      due_date: \"next week\",\n",
      "      task_type: \"Reminder\"\n",
      "    }\n",
      "  ]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import { TraceGroup } from \"npm:langchain@0.0.173/callbacks\";\n",
    "\n",
    "const traceGroup = new TraceGroup(\"CritiqueReviseChain\");\n",
    "const groupManager = await traceGroup.start();\n",
    "\n",
    "const userQuery = `Set a reminder to renew our online property ads next week.`;\n",
    "\n",
    "let result = await generateChain.invoke({\n",
    "  query: userQuery,\n",
    "}, { callbacks: groupManager });\n",
    "\n",
    "console.log(result);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use our original Zod schema's `safeParse` method as a convenient validator to show this. It includes an error object summarizing the problem:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"success\": false,\n",
      "  \"error\": {\n",
      "    \"issues\": [\n",
      "      {\n",
      "        \"received\": \"Reminder\",\n",
      "        \"code\": \"invalid_enum_value\",\n",
      "        \"options\": [\n",
      "          \"Call\",\n",
      "          \"Message\",\n",
      "          \"Todo\",\n",
      "          \"In-Person Meeting\",\n",
      "          \"Email\",\n",
      "          \"Mail\",\n",
      "          \"Text\",\n",
      "          \"Open House\"\n",
      "        ],\n",
      "        \"path\": [\n",
      "          \"tasks\",\n",
      "          0,\n",
      "          \"task_type\"\n",
      "        ],\n",
      "        \"message\": \"Invalid enum value. Expected 'Call' | 'Message' | 'Todo' | 'In-Person Meeting' | 'Email' | 'Mail' | 'Text' | 'Open House', received 'Reminder'\"\n",
      "      }\n",
      "    ],\n",
      "    \"name\": \"ZodError\"\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "const outputValidator = (output: unknown) => zodSchema.safeParse(output);\n",
    "\n",
    "let validatorResult = outputValidator(result);\n",
    "\n",
    "console.log(JSON.stringify(validatorResult, null, 2));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we define a revise chain that will attempt to fix the problem. We reuse the previously defined model with the bound function call arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "const reviseTemplate = `Original prompt:\n",
    "--------------\n",
    "{original_prompt}\n",
    "--------------\n",
    "\n",
    "Completion:\n",
    "--------------\n",
    "{completion}\n",
    "--------------\n",
    "\n",
    "Above, the completion did not satisfy the constraints given by the original prompt and provided schema.\n",
    "\n",
    "Error:\n",
    "--------------\n",
    "{error}\n",
    "--------------\n",
    "\n",
    "Try again. Only respond with an answer that satisfies the constraints laid out in the original prompt and provided schema:`;\n",
    "\n",
    "const revisePrompt = PromptTemplate.fromTemplate(reviseTemplate);\n",
    "\n",
    "const reviseChain = revisePrompt\n",
    "  .pipe(taskFunctionCallModel)\n",
    "  .pipe(new JsonOutputFunctionsParser())\n",
    "  .withConfig({ runName: \"ReviseChain\" });"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And finally, we rerun the erroneous output with the original prompt, completion, and error message until it passes the validation successfully:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"success\": true,\n",
      "  \"data\": {\n",
      "    \"tasks\": [\n",
      "      {\n",
      "        \"title\": \"Renew online property ads\",\n",
      "        \"due_date\": \"next week\",\n",
      "        \"task_type\": \"Todo\"\n",
      "      }\n",
      "    ]\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "let fixCount = 0;\n",
    "\n",
    "const formattedOriginalPrompt = await generatePrompt.format({\n",
    "  query: userQuery\n",
    "});\n",
    "\n",
    "try {\n",
    "  while (!validatorResult.success && fixCount < 5) {\n",
    "    result = await reviseChain.invoke({\n",
    "      original_prompt: formattedOriginalPrompt,\n",
    "      completion: JSON.stringify(result),\n",
    "      error: JSON.stringify(validatorResult.error),\n",
    "    }, { callbacks: groupManager });\n",
    "    validatorResult = outputValidator(result);\n",
    "    fixCount++;\n",
    "  }\n",
    "} finally {\n",
    "  await traceGroup.end();\n",
    "}\n",
    "\n",
    "console.log(JSON.stringify(validatorResult, null, 2));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see an example LangSmith trace of the full chain here: https://smith.langchain.com/public/967974dd-ac6c-42a2-8796-da84589d06a5/r"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Deno",
   "language": "typescript",
   "name": "deno"
  },
  "language_info": {
   "file_extension": ".ts",
   "mimetype": "text/x.typescript",
   "name": "typescript",
   "nb_converter": "script",
   "pygments_lexer": "typescript",
   "version": "5.2.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
